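"""
Activation functions and modules: dynamic ReLU (DyReLU), GELU and its
GCLU/HGELU variants, serial HGELU variants, and the attention-based ADReLU.
"""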

import math
import torch
import torch.nn as nn

from torch.nn import functional as F


class DyReLU(nn.Module):
    """Base class for dynamic ReLU activations.

    A small MLP maps the globally averaged input to ``2 * k`` raw
    coefficients in ``(-1, 1)``; subclasses turn them into slopes and
    intercepts of ``k`` linear pieces and take the elementwise maximum.

    Args:
        channels: number of input channels (dimension 1 of the input).
        reduction: divisor for the hidden width of the coefficient MLP.
        k: number of linear pieces.
        conv_type: ``'1d'`` (input averaged over the last axis) or ``'2d'``
            (input averaged over the last two axes).
    """

    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super().__init__()
        self.channels = channels
        self.k = k
        self.conv_type = conv_type
        assert self.conv_type in ('1d', '2d')

        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, 2 * k)
        self.sigmoid = nn.Sigmoid()

        # Per-coefficient scale: 1.0 for the first k entries, 0.5 for the last k.
        scales = [1.0] * k + [0.5] * k
        # Offset added after scaling: 1 for the very first entry, 0 elsewhere.
        offsets = [1.0] + [0.0] * (2 * k - 1)
        self.register_buffer('lambdas', torch.tensor(scales, dtype=torch.float32))
        self.register_buffer('init_v', torch.tensor(offsets, dtype=torch.float32))

    def get_relu_coefs(self, x):
        """Return raw coefficients ``2 * sigmoid(MLP(mean(x))) - 1``, shape ``(batch, fc2 outputs)``."""
        pooled = x.mean(dim=-1)
        if self.conv_type == '2d':
            # Average the second spatial axis as well.
            pooled = pooled.mean(dim=-1)
        hidden = self.relu(self.fc1(pooled))
        return 2 * self.sigmoid(self.fc2(hidden)) - 1

    def forward(self, x):
        """Subclasses must implement the piecewise-linear combination."""
        raise NotImplementedError


class DyReLUA(DyReLU):
    """DyReLU-A: ``k`` linear pieces whose coefficients are shared by all channels.

    For each sample ``b`` the head yields slopes ``a[b, :k]`` and intercepts
    ``c[b, :k]`` (raw values scaled by ``lambdas`` and shifted by ``init_v``);
    every element of that sample becomes ``max_j(a[b, j] * x + c[b, j])``.
    Only the batch axis is moved, so both ``'1d'`` and ``'2d'`` inputs work.
    """

    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super().__init__(channels, reduction, k, conv_type)
        # Same shape as the base head; rebuilt here (re-initialising its weights).
        self.fc2 = nn.Linear(channels // reduction, 2 * k)

    def forward(self, x):
        assert x.shape[1] == self.channels
        coefs = self.get_relu_coefs(x).view(-1, 2 * self.k) * self.lambdas + self.init_v
        slopes, intercepts = coefs[:, :self.k], coefs[:, self.k:]
        # Swap batch to the last axis and append a piece axis so the
        # (batch, k) coefficients broadcast against every element.
        swapped = x.transpose(0, -1).unsqueeze(-1)
        pieces = swapped * slopes + intercepts
        # Keep the best piece per element, then swap batch back to the front.
        return pieces.max(dim=-1).values.transpose(0, -1)


class DyReLUB(DyReLU):
    """DyReLU-B: ``k`` linear pieces with separate coefficients per channel.

    The head predicts ``2 * k`` values for every channel; element ``x`` of
    channel ``ch`` in sample ``b`` becomes
    ``max_j(a[b, ch, j] * x + c[b, ch, j])``.
    """

    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super().__init__(channels, reduction, k, conv_type)
        # Replace the base head: one set of 2k coefficients per channel.
        self.fc2 = nn.Linear(channels // reduction, 2 * k * channels)

    def forward(self, x):
        assert x.shape[1] == self.channels
        coefs = self.get_relu_coefs(x).view(-1, self.channels, 2 * self.k)
        coefs = coefs * self.lambdas + self.init_v
        slopes = coefs[:, :, :self.k]
        intercepts = coefs[:, :, self.k:]

        # Move batch and channel to the trailing axes so they line up with the
        # (batch, channels, k) coefficients; `to_back` undoes the move.
        if self.conv_type == '1d':
            to_front, to_back = (2, 0, 1), (1, 2, 0)
        elif self.conv_type == '2d':
            to_front, to_back = (2, 3, 0, 1), (2, 3, 0, 1)
        moved = x.permute(*to_front).unsqueeze(-1)
        pieces = moved * slopes + intercepts
        return pieces.max(dim=-1).values.permute(*to_back)




def gelu(x, inplace: bool = False):
    """Exact GELU: ``x * Phi(x)`` with ``Phi`` the standard normal CDF.

    ``inplace`` is accepted for API symmetry but ignored.
    """
    cdf = 0.5 * (1 + torch.erf(x / math.sqrt(2)))  # probability Phi(x)
    return cdf * x


class GELU(nn.Module):
    """Module wrapper around ``gelu``.

    Args:
        inplace: stored on the module but not used; the result is always a new tensor.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        out = gelu(x)
        return out


def gclu_tanh(x, inplace: bool = False):
    """GCLU with the tanh approximation of the normal CDF.

    With ``p`` the tanh-approximated CDF, returns ``p * x`` where ``p < 0.5``
    and ``(2 - p) * x`` otherwise. ``inplace`` is accepted but ignored.
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    p_out = 0.5 * (1 + torch.tanh(inner))  # approximate probability Phi(x)
    weight = torch.where(p_out < 0.5, p_out, 2 - p_out)
    return weight * x


class GCLUTanh(nn.Module):
    """Module wrapper around ``gclu_tanh``.

    Args:
        inplace: stored on the module but not used by ``forward``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        out = gclu_tanh(x)
        return out


def quick_gclu(x, inplace: bool = False):
    """GCLU using the sigmoid approximation ``Phi(x) ~= sigmoid(1.702 x)``.

    Returns ``p * x`` where ``p < 0.5`` and ``(2 - p) * x`` otherwise.
    ``inplace`` is accepted but ignored.
    """
    # Approximate the normal CDF with a scaled sigmoid.
    prob = torch.sigmoid(1.702 * x)
    gate = torch.where(prob < 0.5, prob, 2 - prob)
    return gate * x


class QuickGCLU(nn.Module):
    """Module wrapper around ``quick_gclu``.

    Args:
        inplace: stored on the module but not used by ``forward``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        out = quick_gclu(x)
        return out


def hgelu(x, inplace: bool = False):
    """HGELU: ``p * x`` where ``p < 0.5``, else ``(2 - p) * x``; ``p = Phi(x)``.

    For positive inputs the weight ``2 - p`` stays at or above 1, so the
    derivative of the positive part does not vanish. ``inplace`` is ignored.
    """
    cdf = 0.5 * (1 + torch.erf(x / math.sqrt(2)))  # probability Phi(x)
    # Residual-style weight on the upper half: 1 + (1 - p).
    gate = torch.where(cdf < 0.5, cdf, 2 - cdf)
    return gate * x


class HGELU(nn.Module):
    """Parallel-form HGELU: module wrapper around ``hgelu``.

    Args:
        inplace: stored on the module but not used by ``forward``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        out = hgelu(x)
        return out


class SequecialHGELU(nn.Module):
    """Serial HGELU followed by a ReLU.

    Computes ``relu(x + min(p, 1 - p) * x)`` with ``p = Phi(x)``.

    Args:
        inplace: passed to the ReLU; the ReLU acts on a freshly computed
            tensor, so the input itself is not modified.
    """

    def __init__(self, inplace: bool = True):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        prob = 0.5 * (1 + torch.erf(x / math.sqrt(2)))  # probability Phi(x)
        # min(p, 1 - p) equals the original where(p < 0.5, p, 1 - p).
        gate = torch.minimum(prob, 1 - prob)
        # Residual learning: add the gated input back onto x.
        return self.relu(x + gate * x)


class SequecialHGELUV2(nn.Module):
    """Serial HGELU without a ReLU layer: ``2 * min(p, 1 - p) * x``, ``p = Phi(x)``.

    Args:
        inplace: stored on the module but not used by ``forward``.
    """

    def __init__(self, inplace: bool = True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        cdf = 0.5 * (1 + torch.erf(x / math.sqrt(2)))  # probability Phi(x)
        gate = torch.minimum(cdf, 1 - cdf)
        return gate * x * 2


class SequecialHGELUV3(nn.Module):
    """Serial HGELU without a ReLU layer, with learnable per-channel normalisation.

    The input is standardised per channel with a learned mean ``mu`` and a
    learned log-variance ``log_var``; with ``p`` the normal CDF of the result,
    the output is ``2 * min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C`` (dimension 1 of the input).
        eps: added to the standard deviation before dividing.

    Accepts ``(B, C)`` or ``(B, C, *spatial)`` input; output has the input's shape.

    Raises:
        ValueError: if the input has fewer than two dimensions.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
    ) -> None:
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(1, num_features))
        self.log_var = nn.Parameter(torch.zeros(1, num_features))
        self.eps = eps

    def forward(self, x):
        if x.ndim < 2:
            raise ValueError(
                f"SequecialHGELUV3 expects input of shape (B, C, ...), got {x.ndim}-D"
            )
        # (1, C, 1, ..., 1): per-channel statistics broadcast over batch and
        # every trailing dim (the original handled only 2-D and 4-D inputs).
        stat_shape = (1, -1) + (1,) * (x.ndim - 2)
        mu = self.mu.reshape(stat_shape)
        std = torch.exp(0.5 * self.log_var).reshape(stat_shape)
        # Normalise, then take the normal CDF as a probability.
        norm_out = (x - mu) / (std + self.eps)
        p_out = 0.5 * (1 + torch.erf(norm_out / math.sqrt(2)))
        # Residual-style gate, largest (0.5) where x equals mu.
        weight = torch.where(p_out < 0.5, p_out, 1 - p_out)

        return weight * x * 2



class SequecialHGELUV4(nn.Module):
    """Serial HGELU without a ReLU layer, with input-conditioned normalisation.

    A per-sample channel descriptor is computed: the global spatial average
    for ``(B, C, H, W)`` input, or the input itself for ``(B, C)`` input.
    Two linear heads map it to a per-channel mean ``mu`` and log-variance
    ``log_var``. The input is standardised with them; with ``p`` the normal
    CDF of the result, the output is ``min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C`` (dimension 1 of the input).
        eps: added to the predicted standard deviation before dividing.

    Raises:
        ValueError: if the input is neither 2-D nor 4-D.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
    ) -> None:
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc21 = nn.Linear(num_features, num_features)
        self.fc22 = nn.Linear(num_features, num_features)
        self.eps = eps

    def encode(self, x):
        """Map a ``(B, C)`` descriptor to ``(mu, log_var)``."""
        return self.fc21(x), self.fc22(x)

    def forward(self, x):
        if x.ndim == 4:
            descriptor = torch.flatten(self.avg_pool(x), 1)
        elif x.ndim == 2:
            # AdaptiveAvgPool2d only accepts 3-D/4-D tensors. A (B, C) input
            # has no spatial extent, so it is its own descriptor.
            descriptor = x
        else:
            raise ValueError(
                f"SequecialHGELUV4 expects (B, C) or (B, C, H, W) input, got {x.ndim}-D"
            )
        mu, log_var = self.encode(descriptor)
        std = torch.exp(0.5 * log_var)
        if x.ndim == 4:
            # Broadcast the per-sample, per-channel statistics over H and W.
            b, c = x.shape[:2]
            mu = mu.reshape(b, c, 1, 1)
            std = std.reshape(b, c, 1, 1)
        # Normalise, then take the normal CDF as a probability.
        norm_out = (x - mu) / (std + self.eps)
        p_out = 0.5 * (1 + torch.erf(norm_out / math.sqrt(2)))
        # Residual-style gate, largest (0.5) where x equals mu.
        weight = torch.where(p_out < 0.5, p_out, 1 - p_out)

        return weight * x


class SequecialHGELUV4B(nn.Module):
    """Serial HGELU without a ReLU layer, with a bottleneck statistics encoder.

    A per-sample channel descriptor is computed: the global spatial average
    for ``(B, C, H, W)`` input, or the input itself for ``(B, C)`` input.
    It passes through a ``C -> C // r`` linear layer and optional dropout,
    then two heads predict a per-channel mean ``mu`` and log-variance
    ``log_var``. The input is standardised with them; with ``p`` the normal
    CDF of the result, the output is ``min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C`` (dimension 1 of the input).
        eps: added to the predicted standard deviation before dividing.
        r: bottleneck reduction ratio.
        dropout_p: dropout probability after the bottleneck (0 disables it).

    Raises:
        ValueError: if the input is neither 2-D nor 4-D.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            r: int = 16,
            dropout_p: float = 0,
    ) -> None:
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(num_features, num_features//r)
        self.dropout = nn.Dropout(dropout_p) if dropout_p > 0. else nn.Identity()
        self.fc21 = nn.Linear(num_features//r, num_features)
        self.fc22 = nn.Linear(num_features//r, num_features)
        self.eps = eps

    def encode(self, x):
        """Map a ``(B, C)`` descriptor to ``(mu, log_var)``."""
        # NOTE(review): there is no nonlinearity between fc1 and the heads,
        # so the encoder is affine end to end; confirm this is intended.
        x = self.fc1(x)
        x = self.dropout(x)
        return self.fc21(x), self.fc22(x)

    def forward(self, x):
        if x.ndim == 4:
            descriptor = torch.flatten(self.avg_pool(x), 1)
        elif x.ndim == 2:
            # AdaptiveAvgPool2d only accepts 3-D/4-D tensors. A (B, C) input
            # has no spatial extent, so it is its own descriptor.
            descriptor = x
        else:
            raise ValueError(
                f"SequecialHGELUV4B expects (B, C) or (B, C, H, W) input, got {x.ndim}-D"
            )
        mu, log_var = self.encode(descriptor)
        std = torch.exp(0.5 * log_var)
        if x.ndim == 4:
            # Broadcast the per-sample, per-channel statistics over H and W.
            b, c = x.shape[:2]
            mu = mu.reshape(b, c, 1, 1)
            std = std.reshape(b, c, 1, 1)
        # Normalise, then take the normal CDF as a probability.
        norm_out = (x - mu) / (std + self.eps)
        p_out = 0.5 * (1 + torch.erf(norm_out / math.sqrt(2)))
        # Residual-style gate, largest (0.5) where x equals mu.
        weight = torch.where(p_out < 0.5, p_out, 1 - p_out)

        return weight * x


class SequecialHGELUV4C(nn.Module):
    """Serial HGELU whose normalisation statistics come from a 1x1-conv bottleneck.

    The globally pooled input goes through ``fc1`` (``C -> C // r``), then two
    1x1-conv heads predict a per-channel mean ``mu`` and log-variance
    ``log_var``. With ``p`` the normal CDF of the standardised input, the
    output is ``min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C``.
        eps: added to the predicted standard deviation before dividing.
        r: bottleneck reduction ratio.
        dropout_p: accepted but not used by this module.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        r: int = 16,
        dropout_p: float = 0,
    ) -> None:
        super().__init__()
        bottleneck = num_features // r
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(num_features, bottleneck, 1)
        self.fc21 = nn.Conv2d(bottleneck, num_features, 1)
        self.fc22 = nn.Conv2d(bottleneck, num_features, 1)
        self.eps = eps

    def encode(self, x):
        squeezed = self.fc1(x)
        return self.fc21(squeezed), self.fc22(squeezed)

    def forward(self, x):
        # Statistics have a 1x1 spatial extent, so they broadcast over the input.
        mu, log_var = self.encode(self.avg_pool(x))
        std = (0.5 * log_var).exp()
        z = (x - mu) / (std + self.eps)
        cdf = 0.5 * (1 + torch.erf(z / math.sqrt(2)))
        return torch.minimum(cdf, 1 - cdf) * x


class SequecialHGELUV5(nn.Module):
    """Serial HGELU with statistics from depthwise (grouped) 1x1 convolutions.

    The globally pooled input is mapped channel-by-channel to a mean ``mu``
    and log-variance ``log_var``; with ``p`` the normal CDF of the
    standardised input, the output is ``min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C``.
        eps: added to the predicted standard deviation before dividing.
    """

    def __init__(self, num_features: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # groups == channels: each channel gets its own scale and bias.
        self.fc1 = nn.Conv2d(num_features, num_features, kernel_size=1, stride=1, groups=num_features)
        self.fc2 = nn.Conv2d(num_features, num_features, kernel_size=1, stride=1, groups=num_features)
        self.eps = eps

    def encode(self, x):
        mu = self.fc1(x)
        log_var = self.fc2(x)
        return mu, log_var

    def forward(self, x):
        pooled = self.avg_pool(x)
        mu, log_var = self.encode(pooled)
        std = (0.5 * log_var).exp()
        z = (x - mu) / (std + self.eps)
        cdf = 0.5 * (1 + torch.erf(z / math.sqrt(2)))
        return torch.minimum(cdf, 1 - cdf) * x


class SequecialHGELUV6(nn.Module):
    """Serial HGELU with depthwise 1x1 convolutions applied before pooling.

    Each depthwise conv is applied to the full input and then globally
    averaged to give a per-channel mean ``mu`` and log-variance ``log_var``.
    With ``p`` the normal CDF of the standardised input, the output is
    ``min(p, 1 - p) * x``.

    Args:
        num_features: number of channels ``C``.
        eps: added to the predicted standard deviation before dividing.
    """

    def __init__(self, num_features: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.fc1 = nn.Conv2d(num_features, num_features, kernel_size=1, stride=1, padding=0, groups=num_features)
        self.fc2 = nn.Conv2d(num_features, num_features, kernel_size=1, stride=1, padding=0, groups=num_features)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.eps = eps

    def encode(self, x):
        mu = self.avg_pool(self.fc1(x))
        log_var = self.avg_pool(self.fc2(x))
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        std = (0.5 * log_var).exp()
        z = (x - mu) / (std + self.eps)
        cdf = 0.5 * (1 + torch.erf(z / math.sqrt(2)))
        return torch.minimum(cdf, 1 - cdf) * x


class SequecialHGELUV4D(nn.Module):
    """Baseline that applies only a fully-connected (1x1 conv) layer.

    Args:
        num_features: number of input and output channels.
        eps, r, dropout_p: accepted for signature parity but not used.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        r: int = 16,
        dropout_p: float = 0,
    ) -> None:
        super().__init__()
        self.fc = nn.Conv2d(num_features, num_features, kernel_size=1)

    def forward(self, x):
        out = self.fc(x)
        return out



import torch
import torch.nn as nn
import torch.nn.functional as F

class ADReLU(nn.Module):
    """ReLU-like activation ``max(x, tau(x))`` with a learned, input-dependent floor.

    ``tau`` comes from a lightweight per-pixel attention branch:
    1. The input is blended with its 3x3 local mean.
    2. It is projected to query/key/value maps of width ``dk``.
    3. The value is gated by the sigmoid of the (approximately) cosine
       similarity of q and k.
    4. The result is projected back to ``in_channels``.

    Args:
        in_channels: number of input channels ``C``. Either this or
            ``input_size`` must be given; ``in_channels`` wins if both are.
        dk: channel width of the query/key/value maps.
        input_size: alternative source for the channel count; only
            ``input_size[0]`` is read.

    Raises:
        ValueError: if neither ``in_channels`` nor ``input_size`` is given,
            or (in ``forward``) if the input is not 4-D.
    """

    def __init__(self, in_channels=None, dk=8, input_size=None):
        super().__init__()
        self.dk = dk
        self.eps = 1e-5
        self.min_channels = 8  # not used by this class; kept for attribute compatibility
        self.group_size = 4  # preferred group count for grouped conv / GroupNorm

        if in_channels is not None:
            self.in_channels = in_channels
        elif input_size is not None:
            self.in_channels = input_size[0]
        else:
            raise ValueError("Either in_channels or input_size must be provided")
        self._initialize_layers()

    def _num_groups(self, channels):
        """Return a group count that exactly divides ``channels``.

        Keeps ``min(group_size, channels)`` whenever it divides ``channels``
        (the original choice). Otherwise it falls back to
        ``gcd(group_size, channels)``, because grouped convs and GroupNorm
        reject counts that do not divide the channel number.
        """
        groups = min(self.group_size, channels)
        if channels % groups == 0:
            return groups
        return math.gcd(self.group_size, channels)

    def _initialize_layers(self):
        c_groups = self._num_groups(self.in_channels)
        dk_groups = self._num_groups(self.dk)

        self.qkv_conv = nn.Sequential(
            nn.Conv2d(self.in_channels, self.in_channels, 1, groups=c_groups),
            nn.GroupNorm(c_groups, self.in_channels),
            nn.Conv2d(self.in_channels, 3 * self.dk, 1),
        )

        self.proj = nn.Sequential(
            # Depthwise 3x3 conv mixes spatial context within each value channel.
            nn.Conv2d(self.dk, self.dk, 3, padding=1, groups=self.dk),
            nn.GroupNorm(dk_groups, self.dk),
            nn.Conv2d(self.dk, self.in_channels, 1),
            nn.GroupNorm(c_groups, self.in_channels),
        )

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError(f"ADReLU expects a 4-D (B, C, H, W) input, got {x.dim()}-D")

        # Blend each pixel with its 3x3 neighbourhood mean (zero-padded
        # border cells count toward the mean: count_include_pad defaults to True).
        x_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        x_fused = 0.7 * x + 0.3 * x_pool

        qkv = self.qkv_conv(x_fused)
        q, k, v = qkv.chunk(3, dim=1)

        # L2-normalise q and k along channels so their dot product is an
        # (eps-damped) cosine similarity.
        q = q / (q.norm(dim=1, keepdim=True) + self.eps)
        k = k / (k.norm(dim=1, keepdim=True) + self.eps)

        # One attention weight per pixel, broadcast over the value channels.
        attn = torch.sigmoid((q * k).sum(dim=1, keepdim=True))
        tau = self.proj(attn * v)
        return torch.max(x, tau)
